#include "airpotentials.h"
#include <cmath>
#include <cstdio>
#include "cstdlib"
#include "ctime"


int main(int argc, char **argv)
{
	// Models
	airpotentials model[3][2];

	// Init models (from file)
	model[0][0].LoadModelFromFile("m1 - N2N2.txt");
	model[1][0].LoadModelFromFile("m1 - N2O2.txt");
	model[2][0].LoadModelFromFile("m1 - O2O2.txt");

	model[0][1].LoadModelFromFile("m2 - N2N2.txt");
	model[1][1].LoadModelFromFile("m2 - N2O2.txt");
	model[2][1].LoadModelFromFile("m2 - O2O2.txt");

	double m1x[3] = {0,0,0};
	double x[4][3];
	double conf[14];
	double comptime[3][2];
	
	printf("-------------- Air PES test -----------------\n");

	FILE *dftdata;
	FILE *pr;
	FILE *pr2;

	// Output files
	char *pair[3] = {"N2-N2","N2-O2","O2-O2"};
	char *pairf[3] = {"N2-N2.txt","N2-O2.txt","O2-O2.txt"};
	char *pairfw[3] = {"PES-N2N2.txt","PES-N2O2.txt","PES-O2O2.txt"};
	char *pairfa[3] = {"Ang-N2N2-m1.txt","Ang-N2O2-m1.txt","Ang-O2O2-m1.txt"};
	char *pairfa2[3] = {"Ang-N2N2-m2.txt","Ang-N2O2-m2.txt","Ang-O2O2-m2.txt"};

	printf("\n--- 1. Energy computation test ---\n");
	pr2=fopen("U(R)DFT.csv","w");

	for(int pn=0;pn<3;pn++)
	{
		double Rcur, DFTmin, DFTmax;
		int Rn = -1;

		printf("---  1.%d. %s\n",pn+1,pair[pn]);
		fprintf(pr2,"\n %s;\n",pair[pn]);

		dftdata=fopen(pairf[pn],"r");
		pr=fopen(pairfw[pn],"w");

		double diff = 0, diffrel = 0;
		double diff2 = 0, diffrel2 = 0;
		double shift = 300;
		int rows = 0;
		double r0, r1, R, g0, g1, g2, g0r, g1r, g2r, EeV, EK, EH;

		while (!feof(dftdata))  
		{
			if (pn==1) {
				fscanf(dftdata,"%lf %lf %lf %lf %lf %lf %lf %lf %lf",&r0, &r1, &R, &g0, &g1, &g2, &EeV, &EK, &EH);
			} else {
				fscanf(dftdata,"%lf %lf %lf %lf %lf %lf %lf %lf",&r0, &R, &g0, &g1, &g2, &EeV, &EK, &EH);
				r1 = r0;
			}

			g0r = g0 / 180.0 * M_PI;
			g1r = g1 / 180.0 * M_PI;
			g2r = g2 / 180.0 * M_PI;

			// Atoms coordinates
			model[pn][0].CalcAtomCoordByAngles(m1x,r0,r1,R,g0r,g1r,g2r,&x[0][0]);

			double f1[4][3], f2[4][3];
			double EK_model_1 = model[pn][0].ComputeEnergy(x,&f1[0][0]);
			double EK_model_2 = model[pn][1].ComputeEnergy(x,&f2[0][0]);
			rows++;
			diff += (EK_model_1 - EK)*(EK_model_1 - EK);
			diffrel += (EK_model_1 - EK)*(EK_model_1 - EK)/(EK+shift)/(EK+shift);
			diff2 += (EK_model_2 - EK)*(EK_model_2 - EK);
			diffrel2 += (EK_model_2 - EK)*(EK_model_2 - EK)/(EK+shift)/(EK+shift);

			fprintf(pr,"%lf %lf %lf %lf %lf %lf %lf %lf %lf \n ", r0, r1, R, g0, g1, g2, EK, EK_model_1, EK_model_2);

			if (Rn<0){
				Rn=0;
				Rcur = R;
				DFTmin = EK;
				DFTmax = EK;
			}

			if (Rcur!=R){
				fprintf(pr2,"\n%lf;%e;%e;", Rcur, DFTmin, DFTmax);
				Rn++;
				Rcur = R;
				DFTmin = EK;
				DFTmax = EK;
			}

			if (DFTmin > EK) DFTmin = EK;
			if (DFTmax < EK) DFTmax = EK;

			// Validation of the calculation of forces
			double fsum1[3] = {0,0,0}, fsum2[3] = {0,0,0}, msum1[3] = {0,0,0}, msum2[3] = {0,0,0};
			for(int atn=0;atn<4;atn++){
				
				double m1[3], m2[3];
				model[0][0].vectprod(x[atn],f1[atn],m1);
				model[0][0].vectprod(x[atn],f2[atn],m2);

				for(int j=0;j<3;j++)
				{
					fsum1[j] += f1[atn][j];
					fsum2[j] += f2[atn][j];
					msum1[j] += m1[j];
					msum2[j] += m2[j];
				}
			}
			
			for(int j=0;j<3;j++)
			{
				if (fsum1[j]*fsum1[j] > 0.0000001)
				{
					printf("\n !!! Force field error! ");
				}
				if (msum1[j]*msum1[j] > 0.0000001)
				{
					printf("\n !!! Angular force field error! ");
				}
			}
		}

		fprintf(pr2,"\n%lf;%e;%e;", Rcur, DFTmin, DFTmax);

		printf("Rows: %d, mean-square difference (absolute / relative):\n Model 1 (%lf K /  %lf), Model 2 (%lf K /  %lf) \n", rows, sqrt(diff/rows), sqrt(diffrel/rows),sqrt(diff2/rows), sqrt(diffrel2/rows));
	
		fclose(dftdata);
		fclose(pr);
	}
	fclose(pr2);
	

	printf("\n--- 2. Dependence on angles (for figures) ---\n");
	for(int pn=0;pn<3;pn++)
	{
		printf("---  2.%d. %s\n",pn+1,pair[pn]);

		pr=fopen(pairfa[pn],"w");
		pr2=fopen(pairfa2[pn],"w");

		int rows = 0;
		double r0, r1, R, g0, g1=85, g2=90, g0r, g1r, g2r;

		fprintf(pr,"\n Dependences on g0 for dirrerent R\n");
		fprintf(pr2,"\n Dependences on g0  for dirrerent R\n");

		for(R=3.0;R<=8;R+=0.5)  
		{
			fprintf(pr,"%lf %lf %lf : ", R, g1, g2);
			fprintf(pr2,"%lf %lf %lf : ",R, g1, g2);
			for(g0=-90;g0<=90;g0+=2)  
			{
				if (pn==0) {
					r0=0.545260;
					r1=0.545260;
				} else if (pn==1) {
					r0 = 0.545260;
					r1 = 0.602010;
				} else {
					r0 = 0.602010;
					r1 = 0.602010;
				}

				g0r = g0 / 180.0 * M_PI;
				g1r = g1 / 180.0 * M_PI;
				g2r = g2 / 180.0 * M_PI;

				// Atoms coordinates
				model[pn][0].CalcAtomCoordByAngles(m1x,r0,r1,R,g0r,g1r,g2r,&x[0][0]);

				double f1[12], f2[12];
				double EK_model_1 = model[pn][0].ComputeEnergy(x,f1);
				double EK_model_2 = model[pn][1].ComputeEnergy(x,f2);

				fprintf(pr,"%lf ", EK_model_1);
				fprintf(pr2,"%lf ", EK_model_2);
			}
			fprintf(pr,"\n");
			fprintf(pr2,"\n");
		}

		fprintf(pr,"\n Dependences on g1 for R=4.5, g0=-Pi/2 and different g2\n");
		fprintf(pr2,"\n Dependences on g1  for R=4.5, g0=-Pi/2 and different g2\n");

		R=4.5; g0=-90;
		for(g2=0;g2<=90;g2+=10)  
		{
			fprintf(pr,"%lf %lf %lf : ", R, g0, g2);
			fprintf(pr2,"%lf %lf %lf : ",R, g0, g2);

			for(g1=0;g1<=90;g1+=2)  
			{
				if (pn==0) {
					r0=0.545260;
					r1=0.545260;
				} else if (pn==1) {
					r0 = 0.545260;
					r1 = 0.602010;
				} else {
					r0 = 0.602010;
					r1 = 0.602010;
				}

				g0r = g0 / 180.0 * M_PI;
				g1r = g1 / 180.0 * M_PI;
				g2r = g2 / 180.0 * M_PI;

				// Atoms coordinates
				model[pn][0].CalcAtomCoordByAngles(m1x,r0,r1,R,g0r,g1r,g2r,&x[0][0]);

				double f1[12], f2[12];
				double EK_model_1 = model[pn][0].ComputeEnergy(x,f1);
				double EK_model_2 = model[pn][1].ComputeEnergy(x,f2);

				fprintf(pr,"%lf ", EK_model_1);
				fprintf(pr2,"%lf ", EK_model_2);
			}
			fprintf(pr,"\n");
			fprintf(pr2,"\n");
		}
		fclose(pr);
		fclose(pr2);
	}

	printf("\n--- 3. Dependence on R ---");
	pr=fopen("U(R).csv","w");

	for(int pn=0;pn<3;pn++)
	{
		printf("\n\n---  3.%d. %s",pn+1,pair[pn]);
		printf("\n \t \t\tModel 1 \t \t\tModel 2 ");
		printf("\n \tR \t\tMin \t\tMax \t\tMin \t\tMax");

		fprintf(pr,"\n\n%s;",pair[pn]);
		fprintf(pr,"\n ;Model 1; ;Model 2; ; ");
		fprintf(pr,"\nR;Min;Max;Min;Max;");

		double r0, r1, R, g0, g1=85, g2=90, g0r, g1r, g2r;

		for(R=2.0;R<=8;R+=0.5){

			double E_min[2], E_max[2];

			int isfirst = 1;
			double tstart = (double)clock();

			for(g0=-90;g0<=90;g0+=2)  
			for(g1=0;g1<=90;g1+=2)  
			for(g2=0;g2<=90;g2+=2)  
			{
				if (pn==0) {
					r0=0.545260;
					r1=0.545260;
				} else if (pn==1) {
					r0 = 0.545260;
					r1 = 0.602010;
				} else {
					r0 = 0.602010;
					r1 = 0.602010;
				}

				g0r = g0 / 180.0 * M_PI;
				g1r = g1 / 180.0 * M_PI;
				g2r = g2 / 180.0 * M_PI;

				// Atoms coordinates
				model[pn][0].CalcAtomCoordByAngles(m1x,r0,r1,R,g0r,g1r,g2r,&x[0][0]);

				double f1[12], f2[12];
				double EK_model[2];
				EK_model[0] = model[pn][0].ComputeEnergy(x,f1);
				EK_model[1] = model[pn][1].ComputeEnergy(x,f2);

				if (isfirst==1){ 
					E_min[0] = EK_model[0];
					E_min[1] = EK_model[1];
					E_max[0] = EK_model[0];
					E_max[1] = EK_model[1];
					isfirst = 0;
				}

				if (E_min[0]>EK_model[0]) E_min[0] = EK_model[0];
				if (E_min[1]>EK_model[1]) E_min[1] = EK_model[1];
				if (E_max[0]<EK_model[0]) E_max[0] = EK_model[0];
				if (E_max[1]<EK_model[1]) E_max[1] = EK_model[1];
			}
			
			printf("\n \t%lf \t%e \t%e \t%e \t%e",R,E_min[0],E_max[0],E_min[1],E_max[1]);
			fprintf(pr,"\n%lf;%e;%e;%e;%e;",R,E_min[0],E_max[0],E_min[1],E_max[1]);
		}	
	}
	fclose(pr);

	printf("\n\n--- 4. MD computation test ---\n");

	// Bond length and atom mass
	double r1[3] = {0.545260,0.545260,0.60201};
	double r2[3] = {0.545260,0.60201,0.60201};
	double mN = 0.014/6.022/1000;
	double mO = 0.016/6.022/1000;
	double m1[3] = {mN,mN,mO};
	double m2[3] = {mN,mO,mO};

	double xcm[2][3]; // centers of molecules coordinates
	double vcm[2][3]; // centers of molecules velocities
	double r[4][3]; 
	double L[2][3]; // angular momentum

	double T0 = 300;
	int mn = 0;
	double dt = 0.000001;
	double U, U0, E_k, E_k0, E_k_tr, E_k_r;

	for(mn=0;mn<2;mn++){
		printf("\n Model %d.",mn+1);

		for(int it=0;it<3; it++)
		{
			printf("\n---  4.%d.%d. %s\n",mn+1,it+1,pair[it]);
			fprintf(pr,"\n\n\n---  4.%d. %s\n",it+1,pair[it]);

			comptime[it][mn] = -(double)clock();

			// Initial coordinates
			double R0 = 10; 
			double b = 0;
			double R_min = R0, FE_min = 0, FE_max = 0;
			double v0 = 600;

			xcm[0][0] = 0;
			xcm[0][1] = 0;
			xcm[0][2] = 0;
			xcm[1][0] = R0;
			xcm[1][1] = 0;
			xcm[1][2] = b;

			double Phi = 0;
			double Theta = M_PI/4;
			r[0][0] = r1[it]*sin(Theta)*cos(Phi);
			r[0][1] = r1[it]*sin(Theta)*sin(Phi);
			r[0][2] = r1[it]*cos(Theta);

			Phi = M_PI/4;
			Theta = M_PI/2;
			r[2][0] = r2[it]*sin(Theta)*cos(Phi);
			r[2][1] = r2[it]*sin(Theta)*sin(Phi);
			r[2][2] = r2[it]*cos(Theta);

			for(int j=0;j<3;j++){
				r[1][j] = -r[0][j];
				r[3][j] = -r[2][j];
			}

			for(int j=0;j<3;j++)
			{
				x[0][j] = xcm[0][j] + r[0][j];
				x[1][j] = xcm[0][j] + r[1][j];
				x[2][j] = xcm[1][j] + r[2][j];
				x[3][j] = xcm[1][j] + r[3][j];
			}

			// Initial velocities
			vcm[0][0] = v0/2;		
			vcm[0][1] = 0;		
			vcm[0][2] = 0;		
			vcm[1][0] = -v0/2;		
			vcm[1][1] = 0;		
			vcm[1][2] = 0;		

			for(int j=0;j<3;j++){
				L[0][j] = 0;
				L[1][j] = 0;
			}

			double f[4][3];
			U0 = model[it][mn].ComputeEnergy(x,&f[0][0]);

			// Initial E (in K)
			E_k = 0;
			E_k_tr = 0;
			E_k_r = 0;
			for(int j=0;j<3;j++)
			{
					E_k_tr += m1[it]*vcm[0][j]*vcm[0][j];
					E_k_tr += m2[it]*vcm[1][j]*vcm[1][j];
					E_k_r += m1[it]*L[0][j]*L[0][j]/r1[it]/r1[it];
					E_k_r += m2[it]*L[1][j]*L[1][j]/r2[it]/r2[it];
			}
			E_k_r = E_k_r/0.00138; // in K
			E_k_tr = E_k_tr/0.00138; // in K
			E_k = E_k_r + E_k_tr;
			E_k0 = E_k;

			printf("Initial: U+E= %lf K, U= %lf K, E= %lf,  Er = %lf K, Etr = %lf K\n", U0+E_k0, U0, E_k0, E_k_r,E_k_tr);

			double R0s = sqrt(R0*R0 + b*b), R=R0s;
			while(R<=R0s){
			
				// Move molecules centers
				for(int i=0;i<2;i++)
					for(int j=0;j<3;j++) xcm[i][j] += vcm[i][j]*dt;

				// Rotate molecules
				double Ls[2] = {sqrt(model[it][mn].scalprod(L[0],L[0])),sqrt(model[it][mn].scalprod(L[1],L[1]))};
				double w[3], ang, newr[2][3];

				double eps_const = 0.000001;

				if (Ls[0] > eps_const) {
					for(int j=0;j<3;j++){
						w[j] = L[0][j]/Ls[0];
					}
					ang = Ls[0]/r1[it]/r1[it]*dt;
					model[it][mn].Rotate(r[0],w,ang,newr[0]);
					for(int j=0;j<3;j++){
						r[0][j] = newr[0][j];
					}
				}

				if (Ls[1] > eps_const) {
					for(int j=0;j<3;j++){
						w[j] = L[1][j]/Ls[1];
					}
					ang = Ls[1]/r2[it]/r2[it]*dt;
					model[it][mn].Rotate(r[2],w,ang,newr[1]);
					for(int j=0;j<3;j++){
						r[2][j] = newr[1][j];
					}
				}

				for(int j=0;j<3;j++){
					r[1][j] = -r[0][j];
					r[3][j] = -r[2][j];
				}

				for(int j=0;j<3;j++)
				{
					x[0][j] = xcm[0][j] + r[0][j];
					x[1][j] = xcm[0][j] + r[1][j];
					x[2][j] = xcm[1][j] + r[2][j];
					x[3][j] = xcm[1][j] + r[3][j];
				}

				// Compute forces
				U = model[it][mn].ComputeEnergy(x,&f[0][0]);
				double M[4][3];

				for(int i=0;i<4;i++) model[it][mn].vectprod(r[i],f[i],M[i]);

				// Change velocities of molecules centers
				double fcm[3] = {0,0,0};
				for(int i=0;i<2;i++)
					for(int j=0;j<3;j++) fcm[j] += f[i][j];
				for(int j=0;j<3;j++){
						vcm[0][j] += fcm[j]*dt/(2*m1[it]);
						vcm[1][j] += -fcm[j]*dt/(2*m2[it]);
					}

				// Change angular momentum
				for(int j=0;j<3;j++){
					L[0][j] += (M[0][j] + M[1][j])*dt/m1[it]/2;
					L[1][j] += (M[2][j] + M[3][j])*dt/m2[it]/2;
				}

				// Statistics
				E_k = 0;
				E_k_tr = 0;
				E_k_r = 0;
				for(int j=0;j<3;j++)
				{
						E_k_tr += m1[it]*vcm[0][j]*vcm[0][j];
						E_k_tr += m2[it]*vcm[1][j]*vcm[1][j];
						E_k_r += m1[it]*L[0][j]*L[0][j]/r1[it]/r1[it];
						E_k_r += m2[it]*L[1][j]*L[1][j]/r2[it]/r2[it];
				}
				E_k_r = E_k_r/0.00138;
				E_k_tr = E_k_tr/0.00138;
				E_k = E_k_r + E_k_tr;

				if (U+E_k -U0-E_k0 > FE_max) FE_max = U+E_k -U0-E_k0;
				if (U+E_k -U0-E_k0 < FE_min) FE_min = U+E_k -U0-E_k0;

				double Rv[3];
				for(int j=0;j<3;j++) Rv[j] = xcm[1][j]-xcm[0][j];
				R = sqrt(model[it][mn].scalprod(Rv,Rv));
				if (R<R_min) R_min = R;
			}

			comptime[it][mn] += (double)clock();

			printf("Final: U+E = %lf K, U = %lf K, E = %lf K, Er = %lf K, Etr = %lf K, \nR=%lf, R_min=%lf \n", U+E_k, U, E_k,E_k_r,E_k_tr, R, R_min);
			printf("Energy shift = %e K, shift range (%lf K, %lf K) \n",  U + E_k - (U0 + E_k0), FE_min, FE_max);
		}
	}

	printf("\nComputation time:");
	for(int it=0;it<3; it++)
	{
		printf("\n---  Pair %s, \tModel 1: %e, \tModel 2: %e, \tRelative: %lf",pair[it],comptime[it][0],comptime[it][1],comptime[it][1]/comptime[it][0]);
	}
	
	return 0;
}

